import requests
import sqlite3
import json
import os, time, re, sys

# for retrieving API_KEY
from dotenv import load_dotenv
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from utils.config import load_config
from utils.logger import set_logger

# Load the configuration file
config = load_config()
logger = set_logger()

# Resolve the script directory as an absolute path *before* changing the
# working directory. When `__file__` is relative (the script itself on
# Python < 3.9), `os.path.dirname(__file__)` can be '' (which makes
# `os.chdir` fail), and calling `os.path.abspath(__file__)` after the chdir
# would resolve against the new working directory and give a wrong path.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Setting the working directory to the directory of the script
os.chdir(PROJECT_ROOT)

# Constants
START_DATETIME = config['start_datetime']  # '2023-01-01T00:00:00Z'
END_DATETIME = config['end_datetime']  # '2023-12-31T23:59:59Z'
RADIUS = config['radius']              # "7070m"
COUNTRY_NAME = config['country_name']  # e.g. "mexico"
COUNTRY_CODE = config['country_code']  # e.g. "mx"
YOUTUBE_API_URL = 'https://www.googleapis.com/youtube/v3'

# Convert the start and end datetimes to the format `YYYYMM`
# (first 7 characters, e.g. '2023-01', with every non-digit removed).
START_YEAR_MONTH = re.sub(r'[^0-9]', '', START_DATETIME[:7])
END_YEAR_MONTH = re.sub(r'[^0-9]', '', END_DATETIME[:7])


# PATHS (all derived from PROJECT_ROOT, the directory containing this script)
INPUT_DIR = os.path.join(PROJECT_ROOT, '../../../data', COUNTRY_NAME)
OUTPUT_DIR = os.path.join(PROJECT_ROOT, '../../../output', COUNTRY_NAME)
QUERY_DB_PATH = os.path.join(OUTPUT_DIR, f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection/{START_YEAR_MONTH}_{END_YEAR_MONTH}_01_{COUNTRY_CODE}_query_log.db')
VIDEO_DB_PATH = os.path.join(OUTPUT_DIR, f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection/{START_YEAR_MONTH}_{END_YEAR_MONTH}_02_{COUNTRY_CODE}_youtube_data.db')


class CommentsDatabase:
    """Download YouTube comments for stored videos into SQLite tables.

    On construction the instance opens the videos database and counts the
    tables whose name starts with ``videos_chunk_``. It then offers helpers to
    create a comments table, page through an API endpoint, and store the
    comments of every video listed in a videos chunk table.
    """

    # Pause applied when the API reports `quotaExceeded` (24 hours + 1 minute).
    QUOTA_SLEEP_SECONDS = 24 * 3600 + 60
    # Per-request timeout in seconds, so a stalled connection cannot block forever.
    REQUEST_TIMEOUT = 30
    # Column order shared by the INSERT statement and the value tuples.
    _COMMENT_COLUMNS = (
        'comment_id', 'video_id', 'author_display_name',
        'author_profile_image_url', 'author_channel_url', 'author_channel_id',
        'channel_id', 'text_display', 'text_original', 'parent_id',
        'can_rate', 'viewer_rating', 'like_count', 'moderation_status',
        'published_at', 'updated_at',
    )

    def __init__(self, video_db_path=VIDEO_DB_PATH):
        """Open the database at ``video_db_path`` and count the video chunks.

        Args:
            video_db_path: Path of the SQLite database holding the
                ``videos_chunk_*`` tables.
        """
        # Defined first so close()/__del__ stay safe if setup fails below.
        self.conn = None
        self.c = None
        self.n_chunks = 0

        self.video_db_path = video_db_path
        self.api_key = self.get_api_key()
        self.youtube_api_url = YOUTUBE_API_URL

        # Establish a connection to the database
        logger.info(f"Connecting to the database: {self.video_db_path}...")
        if not os.path.exists(self.video_db_path):
            # sqlite3.connect() creates an empty database file in this case;
            # it will contain no video chunk tables.
            logger.error(f"Database not found: {self.video_db_path}. Creating the database first...")
        self.conn = sqlite3.connect(self.video_db_path)
        self.c = self.conn.cursor()

        # List all tables and count the ones holding video chunks
        self.c.execute("SELECT name FROM sqlite_master WHERE type='table';")
        logger.info("Tables in the database which begin with videos_chunk_:")
        for (table_name,) in self.c.fetchall():
            if table_name.startswith("videos_chunk_"):
                logger.info(table_name)
                self.n_chunks += 1

    def get_n_chunks(self):
        """Return the number of ``videos_chunk_*`` tables found at start-up."""
        return self.n_chunks

    def close(self):
        """Close the database connection; safe to call more than once."""
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()
            self.conn = None

    def __del__(self):
        # The connection may never have been opened if __init__ raised early.
        self.close()

    def get_api_key(self):
        """Load ``../.env`` relative to this script and return ``API_KEY``.

        Returns:
            The value of the ``API_KEY`` environment variable, or ``None``
            when it is not set (an error is logged in that case).
        """
        load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env'))
        api_key = os.getenv('API_KEY')
        if not api_key:
            logger.error("API_KEY is not set; API requests will be sent without a key.")
        return api_key

    def create_table(self, comments_chunk_table):
        """Create the comments table ``comments_chunk_table`` if it is missing.

        Args:
            comments_chunk_table: Name of the table to create. It is
                interpolated into the SQL text, so it must be a trusted
                identifier.
        """
        logger.info("Creating the comments table...")
        # NOTE(review): the foreign key points at a table named `videos`,
        # while this file only reads `videos_chunk_*` tables -- confirm intent.
        self.c.execute(f'''
                    CREATE TABLE IF NOT EXISTS {comments_chunk_table} (
                    comment_id TEXT PRIMARY KEY,
                    video_id TEXT,
                    author_display_name TEXT,
                    author_profile_image_url TEXT,
                    author_channel_url TEXT,
                    author_channel_id TEXT,
                    channel_id TEXT,
                    text_display TEXT,
                    text_original TEXT,
                    parent_id TEXT,
                    can_rate BOOLEAN,
                    viewer_rating TEXT,
                    like_count INTEGER,
                    moderation_status TEXT,
                    published_at DATETIME,
                    updated_at DATETIME,
                    FOREIGN KEY(video_id) REFERENCES videos(video_id)
                )
                ''')
        self.conn.commit()
        logger.info(f"Comments table {comments_chunk_table} created successfully.")

    def fetch_all_pages(self, url, params):
        """Collect the ``items`` of every page returned for ``url``.

        Pages are followed through ``nextPageToken``. A ``quotaExceeded``
        error pauses for ``QUOTA_SLEEP_SECONDS`` and retries the same request.
        Any other API error is logged; fetching then continues only if the
        error payload carries a ``nextPageToken``, otherwise it stops and the
        items gathered so far are returned.

        Args:
            url: Endpoint to query.
            params: Query parameters. The mapping is copied, so the caller's
                object is never modified.

        Returns:
            A list with the ``items`` entries of all pages fetched.

        Raises:
            requests.RequestException: On network failures or timeouts.
            ValueError: If a response body is not valid JSON.
        """
        all_items = []
        query = dict(params)  # private copy: `pageToken` is added below

        while True:
            response = requests.get(url, params=query, timeout=self.REQUEST_TIMEOUT)
            data = response.json()

            if 'error' in data:
                error = data['error']
                if not isinstance(error, dict):
                    error = {'message': str(error)}
                # The detail list may be missing or empty; fall back to {}.
                details = error.get('errors') or []
                first_detail = details[0] if details else {}
                code = error.get('code')
                logger.info(f"Error fetching data: {error.get('message', '')}")

                # Check for rate limit exceeded error
                if code == 403 and first_detail.get('reason') == 'quotaExceeded':
                    logger.info("User Rate Limit Exceeded. Pausing for 24 hours and 1 minute.")
                    time.sleep(self.QUOTA_SLEEP_SECONDS)
                    continue  # Retry the same request after sleeping

                if code in (401, 403) and "authorized" in first_detail.get('message', ''):
                    logger.info("API key is not authorized. Going to the next page...")
                else:
                    logger.info("An unknown error occurred. Going to the next page...")
            else:
                all_items.extend(data.get('items', []))

            # Move on to the next page, if the response advertises one
            next_page_token = data.get('nextPageToken')
            if not next_page_token:
                break
            query['pageToken'] = next_page_token

        return all_items

    def _build_comment_row(self, comment, video_id, parent_id=None):
        """Map one API comment resource to a dict keyed by table column."""
        snippet = comment.get('snippet') or {}
        # Guard against responses in which `authorChannelId` is absent or
        # not an object; the original direct indexing raised in that case.
        author_channel = snippet.get('authorChannelId')
        if isinstance(author_channel, dict):
            author_channel_id = author_channel.get('value', '')
        else:
            author_channel_id = ''

        return {
            'comment_id': comment['id'],
            'video_id': video_id,
            'author_display_name': snippet.get('authorDisplayName', ''),
            'author_profile_image_url': snippet.get('authorProfileImageUrl', ''),
            'author_channel_url': snippet.get('authorChannelUrl', ''),
            'author_channel_id': author_channel_id,
            'channel_id': snippet.get('channelId'),
            'text_display': snippet.get('textDisplay'),
            'text_original': snippet.get('textOriginal'),
            'parent_id': parent_id,
            'can_rate': snippet.get('canRate'),
            'viewer_rating': snippet.get('viewerRating'),
            'like_count': snippet.get('likeCount'),
            'moderation_status': snippet.get('moderationStatus', ''),
            'published_at': snippet.get('publishedAt'),
            'updated_at': snippet.get('updatedAt'),
        }

    def get_comments(self, video_id, comments_chunk_table, max_results=100):
        """Fetch the comment threads of ``video_id`` and store them.

        Top-level comments and the replies embedded in each thread are written
        to ``comments_chunk_table`` with ``INSERT OR REPLACE``. All rows of a
        video are written in one transaction, so a failure leaves no partial
        data for that video.

        Args:
            video_id: ID of the video whose comments are fetched.
            comments_chunk_table: Destination table (trusted identifier).
            max_results: Page size sent to the API as ``maxResults``.

        Returns:
            A list of dicts, one per stored comment, keyed by column name.
        """
        params = {
            'part': 'snippet,replies',
            'videoId': video_id,
            'maxResults': max_results,
            'textFormat': 'plainText',
            'key': self.api_key
        }
        logger.info(f"Fetching all comments data for video ID: {video_id}...")
        # Fetch the comments using the YouTube API
        threads = self.fetch_all_pages(f"{self.youtube_api_url}/commentThreads", params)
        logger.info(f"Fetched {len(threads)} comments for video ID: {video_id}.")

        # Build every row first so malformed data fails before any write.
        all_comments = []
        for thread in threads:
            top_level = thread['snippet']['topLevelComment']
            all_comments.append(self._build_comment_row(top_level, video_id))
            replies = (thread.get('replies') or {}).get('comments') or []
            for reply in replies:
                all_comments.append(
                    self._build_comment_row(reply, video_id, parent_id=top_level['id'])
                )

        for row in all_comments:
            logger.info(f"Inserting comment: {row['comment_id']} for video ID: {video_id}")

        columns = ', '.join(self._COMMENT_COLUMNS)
        placeholders = ', '.join('?' for _ in self._COMMENT_COLUMNS)
        insert_sql = (
            f'INSERT OR REPLACE INTO {comments_chunk_table} ({columns}) '
            f'VALUES ({placeholders})'
        )
        values = [tuple(row[col] for col in self._COMMENT_COLUMNS) for row in all_comments]

        # The connection context manager commits on success and rolls back
        # if any insert raises.
        with self.conn:
            self.c.executemany(insert_sql, values)
        return all_comments

    @staticmethod
    def _to_count(value):
        """Return ``value`` as an int, or 0 when it is NULL or not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def process_video_comments(self, videos_chunk_table, comments_chunk_table):
        """Fetch and store comments for every video in ``videos_chunk_table``.

        Videos whose stored comment count is zero (or missing) are skipped,
        as are videos that already have at least one row in
        ``comments_chunk_table``. A video whose data is malformed is logged
        and skipped. Any other exception is logged and ends processing of the
        chunk; it is not re-raised.

        Args:
            videos_chunk_table: Table providing ``video_id`` and
                ``commentCount`` (trusted identifier).
            comments_chunk_table: Destination table (trusted identifier).
        """
        try:
            # Fetch all videos from the database
            self.c.execute(f'SELECT video_id, commentCount FROM {videos_chunk_table}')
            videos = self.c.fetchall()
            total = len(videos)

            for count, (video_id, raw_count) in enumerate(videos, start=1):
                # The stored value may be NULL or text -- the videos schema is
                # not visible here -- so normalise it before comparing.
                comment_count = self._to_count(raw_count)
                if comment_count > 0:
                    logger.info(f"Processing comments for video ID: {video_id} with {comment_count} comments.")
                    # A single existing row marks the video as already processed
                    self.c.execute(
                        f'SELECT 1 FROM {comments_chunk_table} WHERE video_id = ? LIMIT 1',
                        (video_id,),
                    )
                    if self.c.fetchone() is None:
                        logger.info(f"Comments not processed for video ID: {video_id}. Fetching comments...")
                        try:
                            self.get_comments(video_id, comments_chunk_table)
                        except (KeyError, TypeError, ValueError) as e:
                            # Malformed data for one video must not stop the
                            # remaining videos of the chunk.
                            logger.error(f"Skipping video ID: {video_id} due to malformed data: {e!r}")
                    else:
                        logger.info(f"Comments already processed for video ID: {video_id}")
                else:
                    logger.info(f"No comments to process for video ID: {video_id}")
                logger.info(f"\nCOMMENTS Progress: {count}/{total}.\n")

        except Exception as e:
            logger.error(f"An error occurred while fetching video data: {e}")


  
def main():
    """Create and fill one comments table per ``videos_chunk_<i>`` table.

    Chunks are processed in order from 1 to the number of chunk tables found.
    Each chunk is attempted up to ``max_retries`` times, waiting
    ``retry_delay`` seconds between attempts; the last exception is re-raised
    once the attempts are exhausted.
    """
    retry_delay = 15  # seconds to wait between attempts
    max_retries = 5  # max number of attempts per chunk

    comments_db = CommentsDatabase()
    n_chunks = comments_db.get_n_chunks()

    print(f"Processing comments for {n_chunks} chunks of videos...")

    for i in range(1, n_chunks+1):
        logger.info(f"\nProcessing comments for videos_chunk_{i}...")
        # Each chunk gets its own instance, hence its own database connection.
        comments_db = CommentsDatabase()

        videos_chunk_table = f"videos_chunk_{i}"
        comments_chunk_table = f"comments_chunk_{i}"
        comments_db.create_table(comments_chunk_table)

        for attempt in range(1, max_retries + 1):
            try:
                comments_db.process_video_comments(videos_chunk_table, comments_chunk_table)
            except Exception as e:
                if attempt == max_retries:
                    # No further attempt follows, so do not sleep before raising.
                    logger.error(f"An error occurred: {e}. (Attempt {attempt}/{max_retries})")
                    logger.info("Maximum retry attempts reached. Exiting.")
                    raise  # Re-raise the last exception after final attempt
                logger.error(f"An error occurred: {e}. Retrying in {retry_delay} seconds... (Attempt {attempt}/{max_retries})")
                time.sleep(retry_delay)  # Wait before retrying
            else:
                logger.info("Successfully processed all videos for comments.")
                break  # Exit loop if successful

if __name__ == '__main__':
    main()
